We analyze the random forest model trained in the previous assignment on the FIFA 23 game dataset. The model aims to predict players' wages given their statistics in the game.
We use SHapley Additive exPlanations to compute feature importance.
We also compare the random forest model with the linear model trained in the previous assignment.
We found one observation for which the most important features (according to the shap library) are Value in euro and Ball control, and another observation whose most important features are Overall and Country.
Finding such observations required quite a lot of searching, because for most observations the highest-impact features are Overall and Value in euro.
The first observation:
The second observation:
The results are very similar for the dalex library, although computing Shapley values with dalex took considerably more time.
The first observation in dalex:
The second observation in dalex:
We observe that, according to the shap library, the variable Overall has a positive impact on one observation and a negative impact on another observation.
Finding such an observation was quite easy.
The first observation:
The second observation:
The results are analogous for the dalex library.
The first observation in dalex library:
The second observation in dalex:
The most important features for the random forest model and the linear model usually differ. The tree model mostly focuses on Overall and Value in euro features, while the linear model often focuses on Stats and Position ratings.
The most important features for the tree model and a chosen observation:
The most important features for the linear model and the same observation:
Players A and B are symmetric, so their Shapley values S_a and S_b are equal. Since S_a + S_b + S_c = v({A, B, C}) = 100, we get S_c = 100 - 2 * S_a.
We have 3! * S_a = 20 * 2! + (60 - 20) * 1! * 1! + (70 - 60) * 1! * 1! + (100 - 70) * 2! (the terms are A's marginal contributions when A is, respectively, first in the ordering, second after B, second after C, and third).
Thus 6 * S_a = 40 + 40 + 10 + 60 = 150, and so S_a = 25.
So we have S_a = 25, S_b = 25, S_c = 100 - 2 * S_a = 50.
# 1. Import libraries
!pip3 install shap
!pip3 install dalex
!pip3 install -q condacolab
import condacolab
condacolab.install()
!conda install -c conda-forge python-kaleido
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly
import kaleido
import pickle
import shap
import dalex as dx
from math import isclose
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
# 2. Load dataset and models from the previous homework (POINT 1)
def _load_pickle(path):
    """Deserialize one object from *path*.

    NOTE(security): pickle.load executes arbitrary code on malicious input --
    only use with the trusted files produced by the previous assignment.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)

X_train_load = _load_pickle('X_train.pickle')
y_train_load = _load_pickle('y_train.pickle')
X_test_load = _load_pickle('X_test.pickle')
y_test_load = _load_pickle('y_test.pickle')
forest_reg_load = _load_pickle('tree_model.pickle')
linear_model_load = _load_pickle('linear_model.pickle')

# Sanity check: inspect the loaded splits and both models' predictions.
print(X_train_load)
print(y_train_load)
print(X_test_load)
print(y_test_load)  # BUG FIX: the original printed y_train_load a second time here
print(forest_reg_load.predict(X_train_load))
print(linear_model_load.predict(X_train_load))
print(forest_reg_load.predict(X_test_load))
print(linear_model_load.predict(X_test_load))
# 1. Observe predictions of two observations (POINT 2)
# Draw two rows from the held-out test set and show the forest's wage predictions.
observations = X_test_load.sample(n=2, random_state=1)
predictions = forest_reg_load.predict(observations)
for item in (observations, predictions):
    print(item)
# 2. Calculate Shapley values for selected observations with shap library (POINT 3)
explainer_shap = shap.TreeExplainer(forest_reg_load)
shap_values = explainer_shap.shap_values(observations)
# Local-accuracy check: attributions + base value must reproduce the model predictions.
assert(isclose(np.abs(shap_values.sum(1) + explainer_shap.expected_value - predictions).max(), 0.0, abs_tol = 1e-06))
for plot_id, row_shap_values in enumerate(shap_values, start=1):
    # BUG FIX: start a fresh figure per plot -- consecutive bar_plot/savefig calls
    # on the same current figure can overlay the two charts; close after saving.
    plt.figure()
    shap.bar_plot(row_shap_values, feature_names=observations.columns, show=False)
    plt.savefig(f'Point3_shap_im_{plot_id}.png', dpi=300, bbox_inches='tight')
    plt.close()
# 3.-4. Calculate and plot Shapley values for the selected observations
# with the dalex library (POINT 3)
explainer_dx = dx.Explainer(forest_reg_load, X_test_load, y_test_load, label='default', verbose=False)
shap_values_dx = [
    explainer_dx.predict_parts(observations.iloc[i], type='shap')
    for i in range(len(observations))
]
# Image numbering continues from the shap plots above, hence start=3.
for plot_id, shap_value_dx in enumerate(shap_values_dx, start=3):
    figure = shap_value_dx.plot(show=False)
    figure.write_image(f'Point3_dalex_im{plot_id}.png')
# 5. Searching for two observations in the dataset, such that they have
# different variables of the highest importance (POINT 4)
# Exploratory sweep: plot SHAP attributions for 300 random test rows to eyeball
# which observations have unusual top features.
new_observations = X_test_load.sample(n=300, random_state=25)
grid_shap_values = explainer_shap.shap_values(new_observations)
for row_attributions in grid_shap_values:
    shap.bar_plot(row_attributions, feature_names=new_observations.columns)
# 6. Plotting found observations with different variables of the highest importance (POINT 4)
# Rows -20 and -14 of the 300-row sample were picked manually from the sweep above.
found_observations = new_observations.iloc[[-20, -14]]
found_shap_values = explainer_shap.shap_values(found_observations)
for plot_id, found_shap_value in enumerate(found_shap_values, start=1):
    # BUG FIX: fresh figure per plot so the saved images do not overlay each other.
    plt.figure()
    shap.bar_plot(found_shap_value, feature_names=new_observations.columns, show=False)
    plt.savefig(f'Point4_shap_im{plot_id}.png', dpi=300, bbox_inches='tight')
    plt.close()
# 7. Plotting dalex Shapley values for found observations (POINT 4)
found_shap_values_dx = [
    explainer_dx.predict_parts(found_observations.iloc[i], type='shap')
    for i in range(len(found_observations))
]
for plot_id, found_shap_value_dx in enumerate(found_shap_values_dx, start=1):
    figure = found_shap_value_dx.plot(show=False)
    figure.write_image(f'Point4_dalex_im{plot_id}.png')
# 8. Searching for two observations and a variable in the dataset,
# such that the variable has positive impact on one observation and negative impact on the other one (POINT 5)
# Exploratory sweep over a small sample; inspect the sign of each feature's attribution.
new_observations = X_test_load.sample(n=20, random_state=1)
grid_shap_values = explainer_shap.shap_values(new_observations)
for row_attributions in grid_shap_values:
    shap.bar_plot(row_attributions, feature_names=new_observations.columns)
# 9. Show selected value and observations (POINT 5)
# FIX: reuse the 20-row sample already drawn above with the same random_state
# instead of resampling X_test_load a second time -- the result is identical.
found_observations = new_observations[:2]
found_shap_values = explainer_shap.shap_values(found_observations)
for plot_id, found_shap_value in enumerate(found_shap_values, start=1):
    # Fresh figure per plot so the saved images do not overlay each other.
    plt.figure()
    shap.bar_plot(found_shap_value, feature_names=found_observations.columns, show=False)
    plt.savefig(f'Point5_shap_im{plot_id}.png', dpi=300, bbox_inches='tight')
    plt.close()
# 10. Plotting dalex Shapley values for found observations (POINT 5)
found_shap_values_dx = [
    explainer_dx.predict_parts(found_observations.iloc[i], type='shap')
    for i in range(len(found_observations))
]
for plot_id, found_shap_value_dx in enumerate(found_shap_values_dx, start=1):
    figure = found_shap_value_dx.plot(show=False)
    figure.write_image(f'Point5_dalex_im{plot_id}.png')
# 11. Searching for an observation such that SHAP attributions are different
# between the tree model and the linear model (POINT 7)
# The linear explainer needs background data (X_train_load) to compute expectations.
explainer_linear_shap = shap.Explainer(linear_model_load, X_train_load)
# FIX: removed a dead statement that recomputed shap_values for `observations`;
# the result was never used afterwards.
new_observations = X_test_load.sample(20, random_state = 10)
grid_shap_values = explainer_shap.shap_values(new_observations)
grid_linear_shap_values = explainer_linear_shap(new_observations).values
# Plot tree-model and linear-model attributions side by side for each row.
for tree_attributions, linear_attributions in zip(grid_shap_values, grid_linear_shap_values):
    shap.bar_plot(tree_attributions, feature_names=new_observations.columns)
    shap.bar_plot(linear_attributions, feature_names=new_observations.columns)
# 12. Show selected observation and plot Shapley values for both models (POINT 7)
# FIX: the original header said "POINT 5" -- the Point7_* filenames and the
# surrounding context show this section belongs to POINT 7.
new_observations = X_test_load.sample(1, random_state = 10)
grid_shap_values = explainer_shap.shap_values(new_observations)
grid_linear_shap_values = explainer_linear_shap(new_observations).values
plot_id = 1
for tree_attributions, linear_attributions in zip(grid_shap_values, grid_linear_shap_values):
    # Save the tree-model plot and the linear-model plot with consecutive ids.
    for attributions in (tree_attributions, linear_attributions):
        # Fresh figure per plot so the saved images do not overlay each other.
        plt.figure()
        shap.bar_plot(attributions, feature_names=new_observations.columns, show=False)
        plt.savefig(f'Point7_shap_im{plot_id}.png', dpi=300, bbox_inches='tight')
        plt.close()
        plot_id += 1